# -*- coding: utf-8 -*-
# @author: HRUN

import os
import sys

# 添加当前目录到Python路径
current_dir = os.path.dirname(os.path.abspath(__file__))
backend_dir = os.path.dirname(current_dir)
if backend_dir not in sys.path:
    sys.path.insert(0, backend_dir)


from django_init_simple import setup_django, get_models
from db_manager import safe_get_env_config, safe_save_log


# 初始化Django
setup_django()

import copy
import importlib
import json
import subprocess
from collections import OrderedDict
import re
import os
import sys
from numbers import Number

from django.utils import timezone
from locust import FastHttpUser, task, TaskSet, between, events, LoadTestShape, HttpUser, constant
from django_init_simple import setup_django as setup_django_simple
from locust.env import Environment
from gevent.lock import Semaphore

from requests_toolbelt import MultipartEncoder
import random # Added for on_user_start/stop

global ENV, global_func, global_func_file, current_report_id

def setup_django():
    """初始化Django设置，避免重复设置"""
    try:
        return setup_django_simple()
    except ImportError as e:
        print(f"导入django_init_simple失败: {e}")
        raise


class RealTimeMonitor:
    """实时监控类"""

    def __init__(self, report_id=None):
        self.report_id = report_id
        self.stats_data = []
        self.last_update_time = timezone.now()
        self.update_interval = 5  # 更新间隔到5秒，记录更多日志
        self.current_environment = None
        self.stats_history = []  # 存储历史统计数据
        self.detailed_stats = {}  # 存储详细的接口统计数据
        self.start_time = timezone.now()  # 记录开始时间
        self.request_count = 0  # 请求计数器
        self.log_sample_rate = 0.1  # 日志采样率，默认10%（已废弃，现在使用固定间隔1000）
        self.gui_url = None
        
    def setup(self):
        """设置监控器，注册所有事件监听器"""
        # 注册Locust事件监听器
        events.test_start.add_listener(self.on_test_start)
        events.test_stop.add_listener(self.on_test_stop)
        # 使用统一的request事件，而不是分开的success和failure事件
        events.request.add_listener(self.on_request)
        events.user_error.add_listener(self.on_user_error)
        events.spawning_complete.add_listener(self.on_spawning_complete)
        events.quitting.add_listener(self.on_quitting)
        
        # 高级事件监听
        if hasattr(events, 'user_start'):
            events.user_start.add_listener(self.on_user_start)
        if hasattr(events, 'user_stop'):
            events.user_stop.add_listener(self.on_user_stop)
        if hasattr(events, 'reset_stats'):
            events.reset_stats.add_listener(self.on_reset_stats)
        
    def on_test_start(self, environment, **kwargs):
        """测试开始时的处理"""
        print(f"性能测试开始 - 报告ID: {self.report_id}")
        self.current_environment = environment
        self.stats_history = []
        self.detailed_stats = {}
        self.start_time = timezone.now()  # 更新开始时间
        
        # 记录测试开始日志
        if self.report_id:
            self.log_info("性能测试监控已初始化，准备开始测试")
            if hasattr(environment, 'parsed_options'):
                users = getattr(environment.parsed_options, 'num_users', 0)
                spawn_rate = getattr(environment.parsed_options, 'spawn_rate', 0)
                self.log_info(f"目标用户数: {users}, 启动速率: {spawn_rate}/s")
                
                # 记录更多配置信息
                if hasattr(environment.parsed_options, 'run_time'):
                    run_time = getattr(environment.parsed_options, 'run_time', 'N/A')
                    self.log_info(f"计划运行时间: {run_time}/s")
                
                # 记录主机信息
                self.log_info(f"测试主机: {environment.host or '未指定'}")
                
            # 记录环境信息
            self.log_system_info()
            
    def on_test_stop(self, environment, **kwargs):
        """测试结束时的处理"""
        print(f"性能测试结束 - 报告ID: {self.report_id}")
        
        # 记录测试结束日志
        if self.report_id:
            # 计算测试持续时间
            duration = (timezone.now() - self.start_time).total_seconds()
            duration_str = f"{int(duration // 3600)}小时{int((duration % 3600) // 60)}分{int(duration % 60)}秒"
            
            # 获取最终统计数据
            if environment and hasattr(environment, 'stats'):
                stats = environment.stats
                total_requests = stats.total.num_requests
                total_failures = stats.total.num_failures
                success_rate = ((total_requests - total_failures) / total_requests * 100) if total_requests > 0 else 0
                
                self.log_info(f"性能测试结束 - 持续时间: {duration_str}, 总请求数: {total_requests}, 成功率: {success_rate:.2f}%")
                
                # 记录详细统计
                self.log_info(f"最终统计: RPS: {stats.total.current_rps:.2f}, 平均响应时间: {stats.total.avg_response_time:.2f}ms")
                
                # 记录每个接口的统计数据
                for name, stat in stats.entries.items():
                    if name != "Aggregated":
                        self.log_info(f"接口 [{name}] - 请求数: {stat.num_requests}, 失败数: {stat.num_failures}, "
                                    f"平均响应时间: {stat.avg_response_time:.2f}ms, RPS: {stat.current_rps:.2f}")
                
                # 收集最终统计数据
                final_stats = self.collect_current_stats(environment)
                
                # 调用finalize_report来更新报告状态
                try:
                    from performanceengine.taskResult import finalize_report
                    finalize_success = finalize_report(self.report_id, True, final_stats)
                    if finalize_success:
                        self.log_info("报告状态已更新为已完成")
                    else:
                        self.log_warning("报告状态更新失败")
                except Exception as e:
                    self.log_error(f"更新报告状态失败: {e}")
            else:
                self.log_info(f"性能测试结束 - 持续时间: {duration_str}")
                
                # 没有统计数据，可能是异常终止
                try:
                    from performanceengine.taskResult import finalize_report
                    finalize_report(self.report_id, False, {})
                    self.log_warning("测试异常终止，报告状态已更新")
                except Exception as e:
                    self.log_error(f"更新报告状态失败: {e}")
                
    def on_request_success(self, request_type, name, response_time, response_length, **kwargs):
        """请求成功的处理"""
        self.request_count += 1
        
        # 采样记录日志，避免日志过多（与on_request方法保持一致）
        if self.report_id and (self.request_count % 1000 == 0):
                        
            # 记录慢请求警告
            if response_time > 1000:  # 超过1秒
                self.log_warning(f"高响应时间: {request_type}{name} - {response_time:.2f}ms",
                               request_name=name,
                               response_time=response_time)
                               
    def on_request_failure(self, request_type, name, response_time, response_length, exception, **kwargs):
        """请求失败的处理"""
        self.request_count += 1
        
        # 采样记录失败请求，避免日志过多（与成功请求保持一致）
        if self.report_id and (self.request_count % 1000 == 0):
            # 获取额外信息
            url = kwargs.get('url', '')
            context = kwargs.get('context', {})
            response = kwargs.get('response', None)
            
            # 记录请求失败日志
            status_code = response.status_code if response and hasattr(response, 'status_code') else 'N/A'
            self.log_error(f"请求失败: {request_type} {name} - 响应时间: {response_time:.2f}ms, 状态码: {status_code}, 异常: {exception}", 
                         request_name=name, 
                         response_time=response_time,
                         exception=str(exception))
                         
    def on_user_error(self, user_instance, exception, tb_info, **kwargs):
        """用户错误的处理"""
        if self.report_id:
            self.log_error(f"用户错误: {user_instance.__class__.__name__} - {exception}",
                         exception=str(exception))
                         
    def on_spawning_complete(self, user_count, **kwargs):
        """用户生成完成的处理"""
        if self.report_id:
            self.log_info(f"用户生成完成: {user_count}个用户已准备就绪")
            
    def on_quitting(self, environment, **kwargs):
        """测试退出的处理"""
        if self.report_id:
            
            # 确保在测试退出时也能更新报告状态
            # 这是一个额外的保障措施，以防on_test_stop没有被正确调用
            try:
                # 检查报告状态是否已经更新
                from performance.models import TaskReport
                report = TaskReport.objects.get(id=self.report_id)
                
                # 只有当报告仍处于运行中状态时才更新
                if report.reportStatus == '1':
                    self.log_warning("测试退出但报告状态未更新，尝试更新状态...")
                    
                    # 如果环境对象可用，收集最终统计数据
                    if environment and hasattr(environment, 'stats'):
                        final_stats = self.collect_current_stats(environment)
                        
                        # 调用finalize_report更新报告状态
                        from performanceengine.taskResult import finalize_report
                        finalize_success = finalize_report(self.report_id, True, final_stats)
                        if finalize_success:
                            self.log_info("报告状态已在退出时更新")
                    else:
                        # 没有统计数据，可能是异常退出
                        from performanceengine.taskResult import finalize_report
                        finalize_report(self.report_id, False, {})
                        self.log_warning("测试异常退出，报告状态已更新")
            except Exception as e:
                self.log_error(f"退出时更新报告状态失败: {e}")
            
    def on_user_start(self, user_instance, **kwargs):
        """用户开始的处理"""
        # 采样记录
        if self.report_id and random.random() < 0.1:  # 只记录10%的用户开始事件
            self.log_debug(f"用户开始: {user_instance.__class__.__name__}")
            
    def on_user_stop(self, user_instance, **kwargs):
        """用户停止的处理"""
        # 采样记录
        if self.report_id and random.random() < 0.1:  # 只记录10%的用户停止事件
            self.log_debug(f"用户停止: {user_instance.__class__.__name__}")
            
    def on_reset_stats(self, **kwargs):
        """统计重置的处理"""
        if self.report_id:
            self.log_info("统计数据已重置")
            
    def log_system_info(self):
        """记录系统信息"""
        if not self.report_id:
            return
            
        try:
            import platform
            import psutil
            
            # 系统信息
            system_info = {
                "系统": platform.system(),
                "版本": platform.version(),
                "处理器": platform.processor(),
                "CPU核心数": psutil.cpu_count(logical=True),
                "物理CPU核心数": psutil.cpu_count(logical=False),
                "总内存": f"{psutil.virtual_memory().total / (1024*1024*1024):.2f} GB"
            }
            
            info_str = ", ".join([f"{k}: {v}" for k, v in system_info.items()])
            self.log_info(f"系统信息: {info_str}")
            
            # 当前资源使用情况
            cpu_percent = psutil.cpu_percent(interval=0.1)
            memory_percent = psutil.virtual_memory().percent
            self.log_info(f"当前资源使用: CPU: {cpu_percent}%, 内存: {memory_percent}%")
            
        except Exception as e:
            print(f"记录系统信息失败: {e}")

    def on_request(self, request_type, name, response_time, response_length, response,
                  context, exception, start_time, url, **kwargs):
        """每次请求后的处理"""
        self.request_count += 1

        # 记录请求数据
        self.stats_data.append({
            'request_type': request_type,
            'name': name,
            'response_time': response_time,
            'response_length': response_length,
            'exception': exception,
            'start_time': start_time,
            'url': url,
            'timestamp': timezone.now()
        })

        # 记录日志
        if self.report_id:
            try:
                # 如果有异常，记录错误日志
                if exception:
                    # 采样记录失败请求，避免日志过多（与成功请求保持一致）
                    if self.request_count % 1000 == 0:
                        status_code = response.status_code if (response is not None and hasattr(response, 'status_code')) else 'N/A'
                        self.log_error(f"请求失败: {request_type} {name} - 响应时间: {response_time:.2f}ms, 状态码: {status_code}, 异常: {exception}",
                                     request_name=name,
                                     response_time=response_time,
                                     exception=str(exception))
                else:
                    # 采样记录成功请求，避免日志过多
                    if self.request_count % 1000 == 0:
                        status_code = response.status_code if (response is not None and hasattr(response, 'status_code')) else 200
                        self.log_info(f"请求成功: {request_type} {name} - 响应时间: {response_time:.2f}ms, 状态码: {status_code}",
                                    request_name=name,
                                    response_time=response_time)

                    # 如果响应时间过高，记录警告
                    if response_time > 1000:  # 超过1秒
                        self.log_warning(f"高响应时间: {name} - {response_time:.2f}ms",
                                       request_name=name,
                                       response_time=response_time)
            except Exception as e:
                print(f"记录日志失败: {e}")

        # 定期更新报告（避免频繁更新）
        now = timezone.now()
        if (now - self.last_update_time).total_seconds() >= self.update_interval:
            try:
                if self.current_environment:
                    self.update_report_realtime()
                    self.last_update_time = now
            except Exception as e:
                print(f"实时更新报告失败: {e}")
                
    # 添加日志级别方法
    def log_debug(self, message, **kwargs):
        """记录调试日志"""
        self.log_message('debug', message, **kwargs)
        
    def log_info(self, message, **kwargs):
        """记录信息日志"""
        self.log_message('info', message, **kwargs)
        
    def log_warning(self, message, **kwargs):
        """记录警告日志"""
        self.log_message('warning', message, **kwargs)
        
    def log_error(self, message, **kwargs):
        """记录错误日志"""
        self.log_message('error', message, **kwargs)
        
    def log_critical(self, message, **kwargs):
        """记录严重错误日志"""
        self.log_message('error', f"CRITICAL: {message}", **kwargs)  # 使用error级别，但标记为CRITICAL

    def _get_safe_user_count(self, environment):
        """安全地获取用户数，包含类型检查和错误处理"""
        try:
            if environment and hasattr(environment, 'runner'):
                if hasattr(environment.runner, 'user_count'):
                    raw_user_count = environment.runner.user_count
                    if raw_user_count is not None:
                        return int(raw_user_count)
        except (ValueError, TypeError, AttributeError):
            pass
        return None

    def log_message(self, level, message, **kwargs):
        """记录日志消息"""
        if not self.report_id:
            return
        
        try:
            # 使用统一的Django设置函数
            setup_django()
            
            # 使用安全的用户数获取方法
            user_count = self._get_safe_user_count(self.current_environment)
            
            # 使用线程安全的日志保存
            log = safe_save_log(
                self.report_id, 
                level, 
                message, 
                user_count=user_count,
                request_name=kwargs.get('request_name'),
                response_time=kwargs.get('response_time'),
                exception=kwargs.get('exception')
            )
            # 若保存失败，直接返回，避免对 None 取属性
            if not log:
                return
            
            # 通过WebSocket发送日志更新
            try:
                from channels.layers import get_channel_layer
                from asgiref.sync import async_to_sync
                
                # 构建日志数据
                log_data = None
                if log:
                    log_data = {
                        'id': log.id,
                        'timestamp': log.timestamp.isoformat(),
                        'level': log.level,
                        'message': log.message,
                        'source': log.source,
                        'user_count': log.user_count,
                        'request_name': log.request_name,
                        'response_time': log.response_time,
                        'exception': log.exception
                    }
                
                # 发送到报告频道
                channel_layer = get_channel_layer()
                if channel_layer and log_data:
                    report_group_name = f"performance_report_{self.report_id}"
                    async_to_sync(channel_layer.group_send)(
                        report_group_name,
                        {
                            "type": "log_update",
                            "data": log_data,
                            "timestamp": timezone.now().isoformat()
                        }
                    )
            except Exception as e:
                print(f"发送日志更新失败: {e}")
            
        except Exception as e:
            print(f"保存日志失败: {e}")

    def update_report_realtime(self):
        """实时更新报告"""
        if not self.report_id or not self.current_environment:
            return
            
        try:
            # 动态导入避免循环导入
            from performanceengine.taskResult import update_report_with_results
            
            # 收集当前真实的统计数据
            current_stats = self.collect_current_stats(self.current_environment)
            
            # 添加到历史记录
            self.stats_history.append(current_stats)
            
            # 保持最新100条记录
            if len(self.stats_history) > 100:
                self.stats_history = self.stats_history[-100:]
            
            update_report_with_results(self.report_id, current_stats, self.gui_url)
            
            # 通过WebSocket广播更新 - 使用同步方式
            try:
                self.sync_broadcast_update(current_stats)
            except Exception as e:
                print(f"WebSocket广播失败: {e}")
            
        except Exception as e:
            print(f"实时更新失败: {e}")

    def collect_current_stats(self, environment):
        from performanceengine.main import get_host_ip
        """收集当前统计数据"""
        stats = {'detailed_stats': {}, 'total': {}}
        
        if hasattr(environment, 'stats') and environment.stats:
            # 收集总体统计数据
            total = environment.stats.total
            
            # 计算准确的错误率
            error_rate = 0
            if total.num_requests > 0:
                error_rate = (total.num_failures / total.num_requests) * 100
            
            # 使用安全的用户数获取方法
            current_users = self._get_safe_user_count(environment) or 0
            
            # Calculate current fails per second if needed
            current_fails_per_sec = 0
            if hasattr(total, 'current_fail_per_sec'):
                current_fails_per_sec = total.current_fail_per_sec
            elif hasattr(total, 'num_failures') and hasattr(total, 'start_time'):
                # Calculate fails per second manually if not available
                elapsed_time = (timezone.now() - self.start_time).total_seconds()
                if elapsed_time > 0:
                    current_fails_per_sec = total.num_failures / elapsed_time
            
            stats['total'] = {
                'num_requests': total.num_requests,
                'num_failures': total.num_failures,
                'avg_response_time': total.avg_response_time,
                'min_response_time': total.min_response_time,
                'max_response_time': total.max_response_time,
                'median_response_time': total.median_response_time,
                'current_rps': total.current_rps,
                'current_fails_per_sec': current_fails_per_sec,
                # 直接使用Locust的百分位数方法，确保一致性
                'p50_response_time': total.get_response_time_percentile(0.50),
                'p90_response_time': total.get_response_time_percentile(0.90),
                'p95_response_time': total.get_response_time_percentile(0.95),
                'p99_response_time': total.get_response_time_percentile(0.99),
                'current_users': current_users,
                'current_tps': total.current_rps,  # 保持TPS与RPS一致
                'error_rate': error_rate,
                'elapsed_time': (timezone.now() - self.start_time).total_seconds() if self.start_time else 0
            }
            
            # 添加运行时间
            if hasattr(self, 'start_time'):
                elapsed_time = (timezone.now() - self.start_time).total_seconds()
                stats['total']['elapsed_time'] = elapsed_time
            
            # 收集详细的接口统计数据
            for stat_name, stat_entry in environment.stats.entries.items():
                if stat_name != ("", ""):  # 跳过总体统计
                    name, method = stat_name
                    # 使用与Locust GUI相同的key格式：method + " " + name
                    key = f"{method} {name}" if method else name
                    
                    stats['detailed_stats'][key] = {
                        'name': name,  # 接口路径
                        'method': method,  # HTTP方法
                        'path': name,  # 添加path字段以便前端识别
                        'num_requests': stat_entry.num_requests,
                        'num_failures': stat_entry.num_failures,
                        'success_requests': stat_entry.num_requests - stat_entry.num_failures,
                        'avg_response_time': stat_entry.avg_response_time,
                        'min_response_time': stat_entry.min_response_time,
                        'max_response_time': stat_entry.max_response_time,
                        'median_response_time': stat_entry.median_response_time,
                        # 直接使用Locust的百分位数方法，确保与GUI一致
                        'p50_response_time': stat_entry.get_response_time_percentile(0.50),
                        'p90_response_time': stat_entry.get_response_time_percentile(0.90),
                        'p95_response_time': stat_entry.get_response_time_percentile(0.95),
                        'p99_response_time': stat_entry.get_response_time_percentile(0.99),
                        'current_rps': stat_entry.current_rps,
                        'error_rate': (stat_entry.num_failures / max(stat_entry.num_requests, 1)) * 100,
                        'current_users': current_users,  # 每个接口使用相同的用户数
                        'current_tps': stat_entry.current_rps  # 接口级别的TPS等于RPS
                    }
        
        # 添加时间戳
        stats['timestamp'] = timezone.now().isoformat()

        self.gui_url = f"http://{get_host_ip()}:{str(item.get('web_port', 8089))}"
        return stats



    def sync_broadcast_update(self, stats):
        """同步方式广播更新到WebSocket"""
        try:
            from performance.models import TaskReport
            from channels.layers import get_channel_layer
            from asgiref.sync import async_to_sync
            
            # 获取任务ID
            if self.report_id:
                try:
                    report = TaskReport.objects.get(id=self.report_id)
                    task_id = report.task_id
                    
                    # 使用channel layer同步发送消息
                    channel_layer = get_channel_layer()
                    if channel_layer:
                        group_name = f"performance_monitor_{task_id}"
                        async_to_sync(channel_layer.group_send)(
                            group_name,
                            {
                                "type": "performance_update",
                                "data": stats,
                                "timestamp": timezone.now().isoformat()
                            }
                        )
                        
                        # 同时发送到报告频道
                        report_group_name = f"performance_report_{self.report_id}"
                        async_to_sync(channel_layer.group_send)(
                            report_group_name,
                            {
                                "type": "performance_update",
                                "data": stats,
                                "timestamp": timezone.now().isoformat()
                            }
                        )
                except Exception as e:
                    print(f"同步广播失败: {e}")
        except Exception as e:
            print(f"WebSocket广播初始化失败: {e}")


# 全局监控实例
monitor = None


def setup_monitoring(report_id):
    """设置监控"""
    global monitor
    monitor = RealTimeMonitor(report_id)
    
    # 调用setup方法注册所有事件监听器
    monitor.setup()
    
    # 不再需要手动注册这些事件，setup方法已经注册了所有需要的事件
    # events.test_start.add_listener(monitor.on_test_start)
    # events.test_stop.add_listener(monitor.on_test_stop)
    # events.request.add_listener(monitor.on_request)
    
    return monitor


try:
    global_func = importlib.import_module('global_func')
except ModuleNotFoundError:
    from apitestengine.core import tools as global_func

# 使用统一的Django设置函数
setup_django()




# 集合点逻辑实现
# 获取Django模型 - 延迟获取，避免初始化问题
def get_test_env_model():
    models = get_models()
    return models['TestEnv']

from performanceengine.params import load_data
all_locusts_spawned = Semaphore()
all_locusts_spawned.acquire()  # 阻塞线程

def on_hatch_complete(user_count, **kwargs):
    """
    Select_task类的钩子方法
    """
    # 创建钩子方法
    all_locusts_spawned.release()
events.spawning_complete.add_listener(on_hatch_complete)

n = 0

def get_env_config(env_id, debug=True):
    """获取测试环境的配置 - 使用线程安全的版本"""
    return safe_get_env_config(env_id, debug)



class BaseEnv(dict):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __setattr__(self, key, value):
        super().__setitem__(key, value)

    def __getattr__(self, item):
        return super().__getitem__(item)
ENV = BaseEnv()
# 不同场景存在不同环境
class SubEnvironment(Environment):
    def __init__(self):
        super().__init__()
        self.url = None

class BaseStepDispose():
    @classmethod
    def setUpClass(cls) -> None:
        cls.env = ENV

    def save_global_variable(self, name, value):
        ENV[name] = value

    def api_perform(self, step):

        try:
            data = copy.deepcopy(step).get('content', {})
            # 执行前置脚本
            self.__run_setup_script(data)
            # 发送请求
            response = self.__send_locustRequest(data)
            # 执行后置脚本
            self.__run_teardown_script(response)
        except Exception as e:
            raise Exception(f"{e}")




    def __send_locustRequest(self, data):
        method = data.get('method', 'POST').upper()  # 统一转大写
        request_info = self.construction_headers_data(data)
        # 统一处理所有请求类型
        valid_methods = {'GET', 'POST', 'PUT', 'DELETE', 'HEAD', 'PATCH'}
        if method not in valid_methods:
            raise ValueError(f"Invalid HTTP method: {method}")

        try:
            # 动态调用请求方法（核心优化点）
            with self.client.request(
                    method=method,
                    **request_info,
                    catch_response=True
            ) as response:
                if 200 <= response.status_code < 300:
                    response.success()
                    # print(f"传参：{request_info}", f"成功响应: {response.text}")
                    return response
                else:
                    response.failure(f"状态码异常: {response.status_code}")
                    return None

        except Exception as e:
            # 统一异常处理
            print(f"请求失败: {str(e)}")
            return None

    def construction_headers_data(self, data):
        """处理请求数据"""
        # 获取请求头
        if ENV.get('headers'):
            data['headers'] = {**ENV.get('headers'), **data.get('headers')}
        # 替换用例数据中的变量
        for k, v in list(data.items()):
            if k in ['interface', "headers", 'request', 'file']:
                # 替换变量
                v = self.__parser_variable(v)
                data[k] = v
        # files字段文件上传处理的处理
        files = data.get('file')
        if files:
            if isinstance(files, dict):
                file_data = files.items()
            else:
                file_data = files
            field = []
            for name, file_info in file_data:
                # 判断是否时文件上传(获取文件类型和文件名)
                if len(file_info) == 3 and os.path.isfile(file_info[1]):
                    field.append([name, (file_info[0], open(file_info[1], 'rb'), file_info[2])])
                else:
                    field.append([name, file_info])
            form_data = MultipartEncoder(fields=field)
            data['headers']["Content-Type"] = form_data.content_type
            data['request']['data'] = form_data
            data['files'] = None
        else:
            pass
        # 组织requests 发送请求所需要的参数格式
        request_params = {}
        # requests请求所需的所有字段
        params_fields = ['url', 'method', 'params', 'data', 'json', 'files', 'headers', 'cookies', 'auth', 'timeout',
                         'allow_redirects', 'proxies', 'hooks', 'stream', 'verify', 'cert']
        for k, v in data['request'].items():
            if k in params_fields:
                request_params[k] = v

        request_params['url'] = data.get('url')
        request_params['name'] = data.get('name')  # 使用接口名称作为请求名称
        # 请求头
        request_params['headers'] = data['headers']
        return request_params

    def __run_teardown_script(self, response):
        """执行后置脚本"""
        self._hook_gen.send(response)
        delattr(self, '_hook_gen')
    def __run_setup_script(self, data):
        """执行前置脚本"""
        self._hook_gen = self.__run_script(data)
        next(self._hook_gen)
    def __run_script(test, data):
        # env = test.env
        setup_script = data.get('setup_script')
        if setup_script:
            try:
                exec(setup_script)
            except Exception as e:
                delattr(test, '_hook_gen')
                raise
        response = yield
        teardown_script = data.get('teardown_script')
        if teardown_script:
            try:
                exec(teardown_script)
            except AssertionError as e:
                raise e
            except Exception as e:
                raise
        yield

    def __parser_variable(self, data):
        """替换变量"""
        pattern = r'\{{(.+?)}}'
        old_data = data
        if isinstance(data, OrderedDict):
            data = dict(data)
        data = str(data)

        while re.search(pattern, data):
            res2 = re.search(pattern, data)
            item = res2.group()
            attr = res2.group(1)
            
            # 检查是否是 __INT__ 标记（前端标记为不带引号的变量）
            is_unquoted = False
            if attr.startswith('__INT__'):
                is_unquoted = True
                attr = attr[7:]  # 去掉 __INT__ 前缀，获取真正的变量名
            
            value = ENV.get(attr)
            if value is None:
                raise ValueError('变量引用错误：\n{}\n中的变量{},在当前运行环境中未找到'.format(
                    json.dumps(old_data, ensure_ascii=False, indent=2), attr)
                )
            
            # 处理不带引号的变量（{{name}} 形式）
            if is_unquoted:
                # 找到变量位置，包括外层的引号
                s = data.find(item)
                # 检查前后是否有引号，如果有则一起替换
                if s > 0 and data[s-1] == "'" and s + len(item) < len(data) and data[s + len(item)] == "'":
                    # 替换时包括引号，这样eval后就不是字符串了
                    data = data[:s-1] + str(value) + data[s + len(item) + 1:]
                else:
                    # 没有外层引号，直接替换
                    data = data.replace(item, str(value))
            # 处理普通变量（带引号的变量，"{{name}}" 形式）
            elif isinstance(value, Number):
                s = data.find(item)
                dd = data[s:s + len(item)]
                data = data.replace(dd, str(value))
            elif isinstance(value, str) and "'" in value:
                data = data.replace(item, value.replace("'", '"'))
            else:
                data = data.replace(item, str(value))
        return eval(data)
    def if_perform(self, case):
        pass

    def script_perform(self, case):
        pass


class GenerateTask():
    def __init__(self, control='20'):
        self.control = control
    def create_taskSet_class(self, item: list) -> list:
            task_class = []
            for idx, scence in enumerate(item):
                cls_name = scence.get('name') or 'Demo' + f'_{idx + 1}'
                # 安全转换场景权重
                raw_weight = scence.get('weight', 1)
                try:
                    weight = int(raw_weight) if raw_weight is not None else 1
                except (ValueError, TypeError):
                    weight = 1
                task_funcs = {}
                if scence.get('steps'):
                    for step_idx, step in enumerate(scence['steps']):
                        func_name = self.create_testTask_name(idx, step_idx, step)
                        test_method = self.create_testTask_func(step)
                        task_funcs[func_name] = test_method
                task_funcs['on_start'] = self.on_start
                task_funcs['on_stop'] = self.on_stop

                task_funcs['weight'] = weight
                cls = type(cls_name, (TaskSet,BaseStepDispose,), task_funcs)
                task_class.append(cls)

            return task_class

    def on_start(self):
        '''场景集初始化执行方法'''
        global n
        if self.control == '10':
            n += 1
            all_locusts_spawned.wait()  # 同步锁等待
    def on_stop(self):
        """场景集结束执行方法"""
        pass
    def create_testTask_name(self, case_index, step_index, step):
        """生成唯一的测试方法名，避免重复"""
        step_title = step.get('title', 'UnknownTitle').replace(' ', '_')
        testTask_name = f"testTask_{case_index + 1}_{step_index + 1}_{step_title}"
        return testTask_name
    def create_testTask_func(self, step):
        """创建测试方法，使用函数工厂来动态传递参数"""
        type = step.get('type')
        def test_method(self):
            if type == 'api':
                return self.api_perform(step)
            elif type == 'script':
                return self.script_perform(step)
            elif type == 'if':
                return self.if_perform(step)
            else:
                raise ValueError(f"不支持的步骤类型:{type}")
        if type == 'api':
            # 安全转换步骤权重
            raw_weight = step.get('weight', 1)
            try:
                safe_weight = int(raw_weight) if raw_weight is not None else 1
            except (ValueError, TypeError):
                safe_weight = 1
            if safe_weight <= 0:
                safe_weight = 1
            return task(safe_weight)(test_method)

        return test_method


class Config():
    def __init__(self, config,default_url = ""):
        self.config = config
        self.default_url = default_url
        self.write_config()

    def write_config(self):
        """写入配置文件"""
        # 修改配置以支持混合模式（自动启动+GUI界面）
        config_lines = ["headless = false\n", f"host = {self.default_url}\n", "autostart = true\n", "autoquit = 300\n",]
        if isinstance(self.config, dict):
            time_unit = None
            pressureMode = None
            for key, value in self.config.items():
                if key == 'timeUnit':
                    time_unit = value
                elif key == 'pressureMode':
                    pressureMode = value
                elif key == 'pressureConfig':
                    if pressureMode == '10':
                        config_lines.extend(self.process_pressure_config(value, time_unit))

                elif key == 'master_server':
                    # 处理主服务器配置
                    pass
                elif key == 'worker_servers':
                    # 处理工作服务器配置
                    pass

                elif key == 'logMode':
                    self.run_log(value)

        # 写入配置文件
        try:
            import os
            # 获取当前文件所在目录（performanceengine目录）
            current_dir = os.path.dirname(os.path.abspath(__file__))
            config_path = os.path.join(current_dir, 'locust.conf')
            with open(config_path, "w") as config_file:
                config_file.writelines(config_lines)
        except IOError as e:
            print(f"Error writing to file: {e}")
            # 错误处理重试处理

    def process_pressure_config(self, pressure_config, time_unit) -> list:
        """处理 pressureConfig 配置"""
        lines = []
        for key, value in pressure_config.items():
            if key == 'lastLong' and time_unit:
                lines.append(f"run-time = {value}{time_unit}\n")
            elif key == 'concurrencyNumber':
                lines.append(f"users = {value}\n")
            elif key == 'concurrencyStep':
                lines.append(f"spawn-rate = {value}\n")


        return lines

    def server(self, server_array: list):
        """启动locust性能压测的服务器配置"""
        pass
    def run_log(self, log_mode: int):
        """
        运行过程中日志输出粒度：
        0：关闭
        10：开启-全部日志
        20：开启-仅成功日志
        30：开启-仅失败日志
        """
        pass

    def config_file(self) -> dict:
        """配置项返回给其他模块使用"""
        config = self.config
        filter_config = {key: value for key, value in config.items() if key not in
                         ['id', 'serverNames','creator','create_time','modifier',
                          'update_time', 'name',  'isSetting', 'project',
                          'task', 'resource', ]}

        stages = []
        if filter_config.get('pressureMode', None) == '20':
            config = filter_config.get('pressureConfig')
            ladders = config.get('ladders')
            sum_duration = 0
            for stage in ladders:
                lastLong = stage.get('lastLong',0)
                duration = self.timeUnit(filter_config.get('timeUnit', 's'), lastLong)
                sum_duration += duration
                stages.append({
                    "duration": sum_duration,
                    "users": int(stage.get('concurrencyNumber',0)),
                    "spawn_rate": int(stage.get('concurrencyStep', 0)),
                })
            filter_config['stages'] = stages
            del filter_config['pressureConfig']

        # print(filter_config)
        return  filter_config


    def timeUnit(self, time_unit, value):
        """时间单位转换"""
        value = int(value)
        if time_unit == 's':
            return value
        elif time_unit == 'm':
            return value * 60
        elif time_unit == 'h':
            return value * 3600
        else:
            return value






item = load_data()
env_config = get_env_config(item.get('env'))
ENV = {**env_config.get('ENV', {})}
global_func_file = ENV.get('global_func', b'')
if global_func_file:
    exec(global_func_file, global_func.__dict__)

# 安全获取host，避免NoneType错误
env_env = env_config.get('ENV', {})
host = env_env.get('host', '') if env_env else ''
conf = Config(item.get('presetting', {}), host)
runConf = conf.config_file()
# 创建 GenerateTask 实例
generator = GenerateTask(runConf.get('control'))

cls_list = generator.create_taskSet_class(item.get('scenes'))

# 设置监控
report_id = item.get('report_id')
if report_id:
    setup_monitoring(report_id)


class CreateKitClass(FastHttpUser):
    tasks = cls_list
    if runConf.get('thinkTimeType') == '10':
        wait_time = constant(runConf.get('thinkTime')[0])
    else:
        wait_time = between(runConf.get('thinkTime')[0], runConf.get('thinkTime')[1])
        
    @events.test_start.add_listener
    def on_test_start(environment, **kwargs):
        print("开始执行任务集初始化动作")
        
    @events.test_stop.add_listener
    def on_test_stop(environment, **kwargs):
        print("结束执行任务集初始化动作")


if runConf.get('pressureMode') == '20':
    class StagesShapeWithCustomUsers(LoadTestShape):
        """
        A simply load test shape class that has different user and spawn_rate at
        different stages.

        Keyword arguments:
            stages -- A list of dicts, each representing a stage with the following keys:
                duration -- When this many seconds pass the test is advanced to the next stage #持续时间
                users -- Total user count 用户数
                spawn_rate -- Number of users to start/stop per second 每秒产生或停止的用户数
                user_classes -- 指定的任务让负载更精确的控制该任务，如果不指定就是随机的
                stop -- A boolean that can stop that test at a specific stage 要想在哪个阶段停止运行就设置该值

            stop_at_end -- Can be set to stop once all stages have run.
        """

        stages = runConf.get('stages',[])

        def tick(self):
            run_time = self.get_run_time()

            for stage in self.stages:
                if run_time < stage["duration"]:
                    try:
                        tick_data = (stage["users"], stage["spawn_rate"], stage["user_classes"])
                    except KeyError:
                        tick_data = (stage["users"], stage["spawn_rate"])
                    return tick_data

            return None


if __name__ == '__main__':
    # 导入get_host_ip获取本机IP
    from performanceengine.main import get_host_ip
    
    # 获取本机IP
    host_ip = get_host_ip()
    
    # 使用绝对路径
    script_path = os.path.abspath("performanceengine/taskGenerate.py")
    
    # 获取配置文件路径
    import os
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'locust.conf')
    
    # 从item中获取web_port，如果不存在则默认使用8089
    web_port = str(item.get('web_port', 8089))
    
    # 使用真实IP地址启动Locust
    gui_url = f"http://{host_ip}:{web_port}"
    # 设置监控
    report_id = item.get('report_id')
    if report_id:
        print(f"已启用实时监控，报告ID: {report_id}, GUI地址: {gui_url}")
    subprocess.run([
        "locust", 
        "-f", script_path, 
        "--autostart", 
        "--web-host", "0.0.0.0",  # 绑定所有网络接口
        "--web-port", web_port,
        "--config", config_path  # 使用配置文件
    ])

